import json
import logging
import os
import random
import time
from pathlib import Path
from typing import List, Dict, Optional

from openai import OpenAI
from tqdm import tqdm

# Root logging setup for the script: INFO and above, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class Config:
    """Static settings for the reasoning-generation run."""

    SAMPLE_SIZE = 300      # how many leading questions of INPUT_FILE to process
    API_RATE_LIMIT = 2     # seconds slept after each question that hits the API
    MAX_RETRIES = 3        # API attempts per question before giving up

    BASE_PATH = Path("/process_COT/letter")
    INPUT_FILE = BASE_PATH / "train.json"
    OUTPUT_FILE = BASE_PATH / "reasoning_output_letter300_correct.txt"
    PROGRESS_FILE = BASE_PATH / "progress.json"

    # Credentials and endpoint come from the environment rather than source
    # code, so secrets never need to be committed. When a variable is unset the
    # value is None, which lets the OpenAI client apply its own defaults; a
    # hard-coded empty string would be passed through as-is instead, so every
    # request would fail only after all retries.
    API_KEY = os.environ.get("OPENAI_API_KEY")
    BASE_URL = os.environ.get("OPENAI_BASE_URL")
    MODEL_NAME = "gpt-4o"

class QuestionProcessor:
    """Generate step-by-step reasoning for QA items with an OpenAI chat model.

    Replies are appended to ``Config.OUTPUT_FILE``. The question strings that
    already have output are persisted to ``Config.PROGRESS_FILE`` as a JSON
    list, so an interrupted run can resume without re-querying them.
    """

    def __init__(self):
        self.client = OpenAI(
            api_key=Config.API_KEY,
            base_url=Config.BASE_URL
        )
        # Question strings whose output was written by this or earlier runs.
        self.processed_questions = self._load_progress()

    def _load_progress(self) -> set:
        """Return the set of already-processed questions (empty if no file).

        Raises:
            json.JSONDecodeError: if the progress file exists but is not JSON.
        """
        if Config.PROGRESS_FILE.exists():
            with open(Config.PROGRESS_FILE, 'r', encoding='utf-8') as f:
                return set(json.load(f))
        return set()

    def _save_progress(self, question: str):
        """Mark *question* as processed and persist the whole set to disk.

        The JSON is written to a sibling temporary file that then replaces the
        progress file. An interruption mid-write therefore cannot leave a
        truncated file that would make every later ``_load_progress`` fail.
        """
        self.processed_questions.add(question)
        tmp_path = Config.PROGRESS_FILE.with_name(Config.PROGRESS_FILE.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(list(self.processed_questions), f)
        tmp_path.replace(Config.PROGRESS_FILE)

    def get_completion(self, prompt: str, retries: Optional[int] = None) -> Optional[str]:
        """Send *prompt* to the chat model and return the stripped reply text.

        Args:
            prompt: Content of the user message.
            retries: Maximum number of attempts. ``None`` means
                ``Config.MAX_RETRIES``, read at call time.

        Returns:
            The reply text, or ``None`` if every attempt failed or *retries*
            is not positive.
        """
        if retries is None:
            retries = Config.MAX_RETRIES
        for attempt in range(retries):
            try:
                response = self.client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that provides step-by-step reasoning."},
                        {"role": "user", "content": prompt}
                    ]
                )
                content = response.choices[0].message.content
                if content is None:
                    # Report this clearly instead of as a NoneType .strip() error.
                    raise ValueError("response contained no message content")
                return content.strip()
            except Exception as e:
                logging.warning(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    # Exponential backoff between attempts: 1s, 2s, 4s, ...
                    time.sleep(2 ** attempt)
                else:
                    logging.error(f"API call final failure: {str(e)}")
        return None

    def format_question(self, qa_item: Dict) -> str:
        """Render a QA item as a 'Question:' line followed by a 'Reference Answer:' line."""
        return f"Question: {qa_item['question']}\nReference Answer: {qa_item['answer']}"

    def generate_prompt(self, formatted_q: str) -> str:
        """Build the user prompt for *formatted_q*.

        The prompt contains the instructions, the question, and a
        last-letter-concatenation template the model is asked to follow
        exactly. The template repeats *formatted_q*, including its reference
        answer.
        """
        return f"""Please analyze this question and provide step-by-step reasoning to find the final answer.
{formatted_q}

Format your response exactly like this:
{formatted_q}
The last letter of "Word1" is "x".
The last letter of "Word2" is "y".
[Continue for each word...]
Concatenating them is "xyz".
The answer is "xyz"."""

    def format_response(self, question: str, response: str) -> str:
        """Combine the question and the model reply into one output record.

        The question's 'Reference Answer:' line is dropped. From the reply,
        blank lines and any echoed 'Question:' / 'Reference Answer:' lines are
        dropped. Every remaining line is whitespace-stripped.
        """
        lines = []

        question_lines = question.strip().split('\n')
        lines.extend(line.strip() for line in question_lines if not line.startswith('Reference Answer:'))

        response_lines = response.strip().split('\n')
        for line in response_lines:
            line = line.strip()
            if not line or line.startswith('Question:') or line.startswith('Reference Answer:'):
                continue
            lines.append(line)

        return '\n'.join(lines)

    def process_questions(self, questions: List[Dict]):
        """Query the model for the first ``Config.SAMPLE_SIZE`` questions.

        Questions already in the progress set are skipped. Each successful
        reply is appended to the output file and only then recorded as done.
        Failed questions are not recorded, so a later run retries them.
        """
        Config.OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

        selected_questions = questions[:Config.SAMPLE_SIZE]

        with open(Config.OUTPUT_FILE, 'a', encoding='utf-8') as f:
            for qa in tqdm(selected_questions, desc=f"Processing first {Config.SAMPLE_SIZE} questions"):
                if qa['question'] in self.processed_questions:
                    logging.info(f"Skipping already processed question: {qa['question']}")
                    continue

                formatted_q = self.format_question(qa)
                prompt = self.generate_prompt(formatted_q)

                if response := self.get_completion(prompt):
                    formatted_response = self.format_response(formatted_q, response)
                    f.write(f"{formatted_response}\n\n")
                    # Flush before marking the question done. Otherwise a crash
                    # could persist progress for output still sitting in the
                    # write buffer, and that question would never be retried.
                    f.flush()
                    self._save_progress(qa['question'])

                # Throttle after every API-backed question, successful or not.
                time.sleep(Config.API_RATE_LIMIT)

def main():
    """Load questions from ``Config.INPUT_FILE`` and process them.

    Errors are logged, not raised. Invalid JSON or a top-level value that is
    not a list stops the run early. Any other exception is logged together
    with its traceback.
    """
    try:
        with open(Config.INPUT_FILE, 'r', encoding='utf-8') as f:
            try:
                questions = json.load(f)
            except json.JSONDecodeError as e:
                logging.error(f"JSON parsing error: {str(e)}")
                return
        # process_questions slices the data and indexes each item by key, so
        # anything but a list would only fail later with a confusing TypeError.
        if not isinstance(questions, list):
            logging.error(f"Expected a JSON list of questions, got {type(questions).__name__}")
            return
        logging.info(f"loaded {len(questions)} questions")

        processor = QuestionProcessor()
        processor.process_questions(questions)

        logging.info("Processing completed!")

    except Exception as e:
        # logging.exception also records the traceback, which str(e) alone loses.
        logging.exception(f"Program error: {str(e)}")

if __name__ == '__main__':
    # Start processing only when run as a script, never on import.
    main()